import copy
import sys
from unittest.mock import Mock, patch

import pytest
from click.exceptions import ClickException

import ray.tests.aws.utils.helpers as helpers
import ray.tests.aws.utils.stubs as stubs
from ray.autoscaler._private.aws.config import (
    DEFAULT_AMI,
    _configure_subnet,
    _get_subnets_or_die,
    bootstrap_aws,
    log_to_cli,
)
from ray.autoscaler._private.aws.node_provider import AWSNodeProvider
from ray.autoscaler._private.providers import _get_node_provider
from ray.tests.aws.utils.constants import (
    AUX_SG,
    AUX_SUBNET,
    CUSTOM_IN_BOUND_RULES,
    DEFAULT_CLUSTER_NAME,
    DEFAULT_INSTANCE_PROFILE,
    DEFAULT_KEY_PAIR,
    DEFAULT_LT,
    DEFAULT_SG,
    DEFAULT_SG_AUX_SUBNET,
    DEFAULT_SG_DUAL_GROUP_RULES,
    DEFAULT_SG_WITH_NAME,
    DEFAULT_SG_WITH_NAME_AND_RULES,
    DEFAULT_SG_WITH_RULES,
    DEFAULT_SG_WITH_RULES_AUX_SUBNET,
    DEFAULT_SUBNET,
)


def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
    """
    This test validates that when bootstrap_aws populates the SubnetIds field,
    all of the subnets used belong to the same VPC, and that a SecurityGroup
    in that VPC is correctly configured.

    Also validates that the head IAM role is correctly filled.
    """
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # Add a response with a thousand subnets, all in different VPCs.
    # After filtering, only the subnet in one particular VPC should remain.
    # Thus the SubnetIds field for each available node type should end up
    # as a length-one list after bootstrapping the config.
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    # describe the subnet in use while determining its vpc
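    # (the *_echo stub helpers queue responses that echo back the resources
    # passed to them)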
    stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to create a security group on the VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG)
    # expect new security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG],
    )

    # given no existing default security group inbound rules...
    # expect to authorize all default inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_RULES,
    )

    # expect another call to describe the above security group while checking
    # a second time if it has ip_permissions set ("if not sg.ip_permissions")
    stubs.describe_an_sg_2(
        ec2_client_stub,
        DEFAULT_SG_WITH_RULES,
    )

    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file("example-full.yaml")
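    # _get_subnets_or_die is memoized (hence cache_clear), so reset it to
    # avoid leaking cached subnet lookups into other tests.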
    _get_subnets_or_die.cache_clear()

    # We've filtered down to only one subnet id -- only one of the thousand
    # subnets generated by ec2.subnets.all() belongs to the right VPC.
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["SubnetIds"] == [DEFAULT_SUBNET["SubnetId"]]
        assert node_config["SecurityGroupIds"] == [DEFAULT_SG["GroupId"]]


@pytest.mark.parametrize(
    "correct_az",
    [True, False],
)
def test_create_sg_different_vpc_same_rules(
    iam_client_stub, ec2_client_stub, correct_az: bool
):
    # use default stubs to skip ahead to security group configuration
    stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)

    default_subnet = copy.deepcopy(DEFAULT_SUBNET)
    if not correct_az:
        default_subnet["AvailabilityZone"] = "us-west-2b"

    # given head and worker nodes with custom subnets defined...
    # expect to first describe the head subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [default_subnet])
    # expect to next describe the worker subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [AUX_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to first create a security group on the worker node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG_AUX_SUBNET)
    # expect new worker security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [AUX_SUBNET["VpcId"]],
        [DEFAULT_SG_AUX_SUBNET],
    )
    # expect to next create a security group on the head node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG)
    # expect new head security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG],
    )

    # given no existing default head security group inbound rules...
    # expect to authorize all default head inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_DUAL_GROUP_RULES,
    )
    # given no existing default worker security group inbound rules...
    # expect to authorize all default worker inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_RULES_AUX_SUBNET,
    )

    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    error = None
    try:
        config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
    except ClickException as e:
        error = e

    _get_subnets_or_die.cache_clear()

    if not correct_az:
        assert isinstance(error, ClickException), "Did not get a ClickException!"
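        # bootstrap aborted early, so discard the stub responses it never
        # consumed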
        iam_client_stub._queue.clear()
        ec2_client_stub._queue.clear()
        return

    # expect the bootstrapped config to show different head and worker security
    # groups residing on different subnets
    for node_type_key, node_type in config["available_node_types"].items():
        node_config = node_type["node_config"]
        security_group_ids = node_config["SecurityGroupIds"]
        subnet_ids = node_config["SubnetIds"]
        if node_type_key == config["head_node_type"]:
            assert security_group_ids == [DEFAULT_SG["GroupId"]]
            assert subnet_ids == [DEFAULT_SUBNET["SubnetId"]]
        else:
            assert security_group_ids == [AUX_SG["GroupId"]]
            assert subnet_ids == [AUX_SUBNET["SubnetId"]]

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_client_stub):
    # use default stubs to skip ahead to security group configuration
    stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)

    # expect to describe the head subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to create a security group on the head node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME)
    # expect new head security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG_WITH_NAME],
    )

    # given custom existing default head security group inbound rules...
    # expect to authorize both default and custom inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_NAME_AND_RULES,
    )

    # given the prior modification to the head security group...
    # expect the next read of a head security group property to reload it
    stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)

    _get_subnets_or_die.cache_clear()
    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file("example-security-group.yaml")

    # expect the bootstrapped config to have the custom security group
    # name and inbound rules
    assert (
        config["provider"]["security_group"]["GroupName"]
        == DEFAULT_SG_WITH_NAME_AND_RULES["GroupName"]
    )
    assert (
        config["provider"]["security_group"]["IpPermissions"] == CUSTOM_IN_BOUND_RULES
    )

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_subnet_given_head_and_worker_sg(iam_client_stub, ec2_client_stub):
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # list a security group and a thousand subnets in different vpcs
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    config = helpers.bootstrap_aws_example_config_file(
        "example-head-and-worker-security-group.yaml"
    )

    # check that just the single subnet in the right vpc is filled
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["SubnetIds"] == [DEFAULT_SUBNET["SubnetId"]]

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


# Parametrize across multiple regions, since default AMI is different in each
@pytest.mark.parametrize(
    "iam_client_stub,ec2_client_stub,region",
    [3 * (region,) for region in DEFAULT_AMI],
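    # each (region, region, region) tuple feeds the two indirect client stub
    # fixtures plus the plain `region` argument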
    indirect=["iam_client_stub", "ec2_client_stub"],
)
def test_fills_out_amis_and_iam(iam_client_stub, ec2_client_stub, region):
    # Set up expected key pair for specific region
    region_key_pair = DEFAULT_KEY_PAIR.copy()
    region_key_pair["KeyName"] = DEFAULT_KEY_PAIR["KeyName"].replace(
        "us-west-2", region
    )

    # Setup stubs to mock out boto3
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(
        ec2_client_stub, region=region, expected_key_pair=region_key_pair
    )
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    config = helpers.load_aws_example_config_file("example-full.yaml")
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    del head_node_config["ImageId"]
    del worker_node_config["ImageId"]

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    config["provider"]["region"] = region

    defaults_filled = bootstrap_aws(config)

    ami = DEFAULT_AMI.get(defaults_filled.get("provider", {}).get("region"))

    for node_type in defaults_filled["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config.get("ImageId") == ami

    # Correctly configured IAM role
    assert defaults_filled["head_node"]["IamInstanceProfile"] == {
        "Arn": DEFAULT_INSTANCE_PROFILE["Arn"]
    }
    # Workers of the head's type do not get the IAM role.
    head_type = config["head_node_type"]
    assert (
        "IamInstanceProfile" not in defaults_filled["available_node_types"][head_type]
    )

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_iam_already_configured(iam_client_stub, ec2_client_stub):
    """
    Checks that things work as expected when IAM role is supplied by user.
    """
    stubs.configure_key_pair_default(ec2_client_stub)
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    config = helpers.load_aws_example_config_file("example-full.yaml")
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    head_node_config["IamInstanceProfile"] = "mock_profile"

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    defaults_filled = bootstrap_aws(config)
    filled_head = defaults_filled["available_node_types"]["ray.head.default"][
        "node_config"
    ]
    assert filled_head["IamInstanceProfile"] == "mock_profile"
    assert "IamInstanceProfile" not in defaults_filled["head_node"]

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
    """
    Test AWS Bootstrap logic when config being bootstrapped has the
    following properties:

    (1) auth config does not specify ssh key path
    (2) available_node_types is provided
    (3) security group name and ip permissions set in provider field
    (4) Available node types have SubnetIds field set and this
        field is of form SubnetIds: [subnet-xxxxx].
        Both node types specify the same subnet-xxxxx.

    Tests creation of a security group and key pair under these conditions.
    """

    # Generate a config of the desired form.
    subnet_id = DEFAULT_SUBNET["SubnetId"]

    # security group info to go in provider field
    provider_data = helpers.load_aws_example_config_file("example-security-group.yaml")[
        "provider"
    ]

    # a multi-node-type config -- will add head/worker stuff and security group
    # info to this.
    base_config = helpers.load_aws_example_config_file("example-full.yaml")

    config = copy.deepcopy(base_config)
    # Add security group data
    config["provider"] = provider_data
    # Add head and worker fields.
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]
    head_node_config["SubnetIds"] = [subnet_id]
    worker_node_config["SubnetIds"] = [subnet_id]

    # Generate stubs
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # Only one of these (the one specified in the available_node_types)
    # is in the correct vpc.
    # This list of subnets is generated by the ec2.subnets.all() call
    # and then ignored, since available_node_types already specify
    # subnet_ids.
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    # The rest of the stubbing logic is copied from
    # test_create_sg_with_custom_inbound_rules_and_name.

    # expect to describe the head subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to create a security group on the head node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME)
    # expect new head security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG_WITH_NAME],
    )

    # given custom existing default head security group inbound rules...
    # expect to authorize both default and custom inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_NAME_AND_RULES,
    )

    # given the prior modification to the head security group...
    # expect the next read of a head security group property to reload it
    stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)

    _get_subnets_or_die.cache_clear()

    # given our mocks and the config as input...
    # expect the config to be validated and bootstrapped successfully
    bootstrapped_config = helpers.bootstrap_aws_config(config)

    # expect the bootstrapped config to have the custom security group
    # name and inbound rules
    assert (
        bootstrapped_config["provider"]["security_group"]["GroupName"]
        == DEFAULT_SG_WITH_NAME_AND_RULES["GroupName"]
    )
    assert (
        bootstrapped_config["provider"]["security_group"]["IpPermissions"]
        == CUSTOM_IN_BOUND_RULES
    )

    # Confirming correct security group got filled for head and workers
    sg_id = DEFAULT_SG["GroupId"]
    for node_type in bootstrapped_config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["SecurityGroupIds"] == [sg_id]

    # Confirming bootstrap config updates available node types with the
    # default KeyName
    for node_type in bootstrapped_config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["KeyName"] == DEFAULT_KEY_PAIR["KeyName"]

    # Confirm security group is in the right VPC.
    # (Doesn't really confirm anything except for the structure of this test
    # data.)
    bootstrapped_head_type = bootstrapped_config["head_node_type"]
    bootstrapped_types = bootstrapped_config["available_node_types"]
    bootstrapped_head_config = bootstrapped_types[bootstrapped_head_type]["node_config"]
    assert DEFAULT_SG["VpcId"] == DEFAULT_SUBNET["VpcId"]
    assert DEFAULT_SUBNET["SubnetId"] == bootstrapped_head_config["SubnetIds"][0]

    # ssh private key filled in
    assert "ssh_private_key" in bootstrapped_config["auth"]

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_missing_keyname(iam_client_stub, ec2_client_stub):
    config = helpers.load_aws_example_config_file("example-full.yaml")
    config["auth"]["ssh_private_key"] = "/path/to/private/key"
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    # Setup stubs to mock out boto3. Should fail on assertion after
    # checking KeyName/UserData.
    stubs.configure_iam_role_default(iam_client_stub)

    missing_user_data_config = copy.deepcopy(config)
    with pytest.raises(AssertionError):
        # Config specified ssh_private_key, but missing KeyName/UserData in
        # node configs
        bootstrap_aws(missing_user_data_config)

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    # Set UserData for both node configs
    head_node_config["UserData"] = {"someKey": "someValue"}
    worker_node_config["UserData"] = {"someKey": "someValue"}

    # Stubs to mock out boto3. Should no longer fail on assertion
    # and go on to describe security groups + configure subnet
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    # Should work without error now that UserData is set
    bootstrap_aws(config)

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_log_to_cli(iam_client_stub, ec2_client_stub):
    config = helpers.load_aws_example_config_file("example-full.yaml")

    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    config = helpers.bootstrap_aws_config(config)

    # Only side-effect is to print logs to cli, called just to
    # check that it runs without error
    log_to_cli(config)
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_network_interfaces(
    ec2_client_stub,
    iam_client_stub,
    ec2_client_stub_fail_fast,
    ec2_client_stub_max_retries,
):
    # use default stubs to skip ahead to subnet configuration
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # given the security groups associated with our network interfaces...
    sgids = ["sg-00000000", "sg-11111111", "sg-22222222", "sg-33333333"]
    security_groups = []
    for suffix, sgid in enumerate(sgids):
        sg = copy.deepcopy(DEFAULT_SG)
        sg["GroupName"] += f"-{suffix}"
        sg["GroupId"] = sgid
        security_groups.append(sg)
    # expect to describe all security groups to ensure they share the same VPC
    stubs.describe_sgs_by_id(ec2_client_stub, sgids, security_groups)

    # use a default stub to skip subnet configuration
    stubs.configure_subnet_default(ec2_client_stub)
    stubs.describe_subnets_echo(
        ec2_client_stub,
        [DEFAULT_SUBNET, {**DEFAULT_SUBNET, "SubnetId": "subnet-11111111"}],
    )
    stubs.describe_subnets_echo(
        ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-22222222"}]
    )
    stubs.describe_subnets_echo(
        ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-33333333"}]
    )

    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file(
        "example-network-interfaces.yaml"
    )

    # instantiate a new node provider
    new_provider = _get_node_provider(
        config["provider"],
        DEFAULT_CLUSTER_NAME,
        False,
    )

    for name, node_type in config["available_node_types"].items():
        node_cfg = node_type["node_config"]
        tags = helpers.node_provider_tags(config, name)
        # given our bootstrapped node config as input to create a new node...
        # expect to first describe all stopped instances that could be reused
        stubs.describe_instances_with_any_filter_consumer(ec2_client_stub_max_retries)
        # given no stopped EC2 instances to reuse...
        # expect to create new nodes with the given network interface config
        stubs.run_instances_with_network_interfaces_consumer(
            ec2_client_stub_fail_fast,
            node_cfg["NetworkInterfaces"],
        )
        new_provider.create_node(node_cfg, tags, 1)

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()
    ec2_client_stub_fail_fast.assert_no_pending_responses()
    ec2_client_stub_max_retries.assert_no_pending_responses()


def test_network_interface_conflict_keys():
    # If NetworkInterfaces are defined, SubnetId, SubnetIds, and
    # SecurityGroupIds can't be specified in the same node type config.
    conflict_kv_pairs = [
        ("SubnetId", "subnet-0000000"),
        ("SubnetIds", ["subnet-0000000", "subnet-1111111"]),
        ("SecurityGroupIds", ["sg-1234abcd", "sg-dcba4321"]),
    ]
    expected_error_msg = (
        "If NetworkInterfaces are defined, subnets and "
        "security groups must ONLY be given in each "
        "NetworkInterface."
    )
    for conflict_kv_pair in conflict_kv_pairs:
        config = helpers.load_aws_example_config_file("example-network-interfaces.yaml")
        head_name = config["head_node_type"]
        head_node_cfg = config["available_node_types"][head_name]["node_config"]
        head_node_cfg[conflict_kv_pair[0]] = conflict_kv_pair[1]
        with pytest.raises(ValueError, match=expected_error_msg):
            helpers.bootstrap_aws_config(config)


def test_network_interface_missing_subnet():
    # If NetworkInterfaces are defined, each must have a subnet ID
    expected_error_msg = (
        "NetworkInterfaces are defined but at least one is "
        "missing a subnet. Please ensure all interfaces "
        "have a subnet assigned."
    )
    config = helpers.load_aws_example_config_file("example-network-interfaces.yaml")
    for name, node_type in config["available_node_types"].items():
        node_cfg = node_type["node_config"]
        for network_interface_cfg in node_cfg["NetworkInterfaces"]:
            network_interface_cfg.pop("SubnetId")
            with pytest.raises(ValueError, match=expected_error_msg):
                helpers.bootstrap_aws_config(config)


def test_network_interface_missing_security_group():
    # If NetworkInterfaces are defined, each must have security groups
    expected_error_msg = (
        "NetworkInterfaces are defined but at least one is "
        "missing a security group. Please ensure all "
        "interfaces have a security group assigned."
    )
    config = helpers.load_aws_example_config_file("example-network-interfaces.yaml")
    for name, node_type in config["available_node_types"].items():
        node_cfg = node_type["node_config"]
        for network_interface_cfg in node_cfg["NetworkInterfaces"]:
            network_interface_cfg.pop("Groups")
            with pytest.raises(ValueError, match=expected_error_msg):
                helpers.bootstrap_aws_config(config)


def test_launch_templates(
    ec2_client_stub, ec2_client_stub_fail_fast, ec2_client_stub_max_retries
):
    # given the launch template associated with our default head node type...
    # expect to first describe the default launch template by ID
    stubs.describe_launch_template_versions_by_id_default(ec2_client_stub, ["$Latest"])
    # given the launch template associated with our default worker node type...
    # expect to next describe the same default launch template by name
    stubs.describe_launch_template_versions_by_name_default(ec2_client_stub, ["2"])
    # use default stubs to skip ahead to subnet configuration
    stubs.configure_key_pair_default(ec2_client_stub)

    # given the security groups associated with our launch template...
    sgids = [DEFAULT_SG["GroupId"]]
    security_groups = [DEFAULT_SG]
    # expect to describe all security groups to ensure they share the same VPC
    stubs.describe_sgs_by_id(ec2_client_stub, sgids, security_groups)

    # use a default stub to skip subnet configuration
    stubs.configure_subnet_default(ec2_client_stub)

    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file("example-launch-templates.yaml")

    # instantiate a new node provider
    new_provider = _get_node_provider(
        config["provider"],
        DEFAULT_CLUSTER_NAME,
        False,
    )

    max_count = 1
    for name, node_type in config["available_node_types"].items():
        # given our bootstrapped node config as input to create a new node...
        # expect to first describe all stopped instances that could be reused
        stubs.describe_instances_with_any_filter_consumer(ec2_client_stub_max_retries)
        # given no stopped EC2 instances to reuse...
        # expect to create new nodes with the given launch template config
        node_cfg = node_type["node_config"]
        stubs.run_instances_with_launch_template_consumer(
            ec2_client_stub_fail_fast,
            config,
            node_cfg,
            name,
            DEFAULT_LT["LaunchTemplateData"],
            max_count,
        )
        tags = helpers.node_provider_tags(config, name)
        new_provider.create_node(node_cfg, tags, max_count)

    ec2_client_stub.assert_no_pending_responses()
    ec2_client_stub_fail_fast.assert_no_pending_responses()
    ec2_client_stub_max_retries.assert_no_pending_responses()


@pytest.mark.parametrize("num_on_demand_nodes", [0, 1001, 9999])
@pytest.mark.parametrize("num_spot_nodes", [0, 1001, 9999])
@pytest.mark.parametrize("stop", [True, False])
def test_terminate_nodes(num_on_demand_nodes, num_spot_nodes, stop):
    # This test makes sure that we stop or terminate all the nodes we're
    # supposed to stop or terminate when we call "terminate_nodes". This test
    # also makes sure that we don't try to stop or terminate too many nodes in
    # a single EC2 request. By default, only 1000 nodes can be
    # stopped/terminated in one request. To terminate more nodes, we must
    # break them up into multiple smaller requests.
    #
    # "num_on_demand_nodes" is the number of on-demand nodes to stop or
    #   terminate.
    # "num_spot_nodes" is the number of spot nodes to terminate.
    # "stop" is True if we want to stop nodes, and False to terminate nodes.
    #   Note that spot instances are always terminated, even if "stop" is True.
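    #
    # For example (illustrative): with num_on_demand_nodes=1001 and
    # stop=False, terminate_instances must be called at least twice, since a
    # single call may include at most provider.max_terminate_nodes
    # (1000 by default) instance IDs.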

    # Generate a list of unique instance ids to terminate
    on_demand_nodes = {"i-{:017d}".format(i) for i in range(num_on_demand_nodes)}
    spot_nodes = {
        "i-{:017d}".format(i + num_on_demand_nodes) for i in range(num_spot_nodes)
    }
    node_ids = list(on_demand_nodes.union(spot_nodes))

    with patch("ray.autoscaler._private.aws.node_provider.make_ec2_resource"):
        provider = AWSNodeProvider(
            provider_config={"region": "nowhere", "cache_stopped_nodes": stop},
            cluster_name="default",
        )

    # "_get_cached_node" is used by the AWSNodeProvider to determine whether a
    # node is a spot instance or an on-demand instance.
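    # Here a non-empty spot_instance_request_id marks a node as a spot
    # instance; an empty string marks it as on-demand.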
    def mock_get_cached_node(node_id):
        result = Mock()
        result.spot_instance_request_id = (
            "sir-08b93456" if node_id in spot_nodes else ""
        )
        return result

    provider._get_cached_node = mock_get_cached_node

    provider.terminate_nodes(node_ids)

    stop_calls = provider.ec2.meta.client.stop_instances.call_args_list
    terminate_calls = provider.ec2.meta.client.terminate_instances.call_args_list

    nodes_to_stop = set()
    nodes_to_terminate = spot_nodes

    if stop:
        nodes_to_stop.update(on_demand_nodes)
    else:
        nodes_to_terminate.update(on_demand_nodes)

    for calls, nodes_to_include_in_call in (stop_calls, nodes_to_stop), (
        terminate_calls,
        nodes_to_terminate,
    ):
        nodes_included_in_call = set()
        for call in calls:
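            # each entry in a Mock call_args_list is (args, kwargs); index 1
            # is the kwargs dict passed to the stubbed boto3 call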
            assert len(call[1]["InstanceIds"]) <= provider.max_terminate_nodes
            nodes_included_in_call.update(call[1]["InstanceIds"])

        assert nodes_to_include_in_call == nodes_included_in_call


def test_retry_get_node():
    """
    This tests _get_node() retries `ec2.instances.filter` if the first call returns an empty list.
    This is important because the EC2 API is eventually consistent, and it may take some time for the
    instance to be available after it has been launched.
    """

    with patch("ray.autoscaler._private.aws.node_provider.make_ec2_resource"):
        provider = AWSNodeProvider(
            provider_config={"region": "nowhere"},
            cluster_name="default",
        )

    attempts = 0
    fake_instance = {"id": "i-1234567890abcdef0"}

    def mock_filter(*args, **kwargs):
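        # Simulate EC2 eventual consistency: the first describe call for the
        # instance returns nothing; subsequent calls return the instance.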
        nonlocal attempts
        if kwargs.get("InstanceIds") == [fake_instance["id"]]:
            attempts += 1
            if attempts > 1:
                return [fake_instance]
        return []

    provider.ec2.instances.filter.side_effect = mock_filter

    assert provider._get_node(fake_instance["id"]) == fake_instance
    assert attempts == 2


def test_use_subnets_ordered_by_az(ec2_client_stub):
    """
    This test validates that when bootstrap_aws populates the SubnetIds field,
    the subnets are ordered the same way as availability zones.

    """
    # Add a response with a twenty subnets round-robined across the 4 AZs in
    # `us-west-2` (a,b,c,d). At the end we should only have 15 subnets, ordered
    # first from `us-west-2c`, then `us-west-2d`, then `us-west-2a`.
    stubs.describe_twenty_subnets_in_different_azs(ec2_client_stub)

    base_config = helpers.load_aws_example_config_file("example-full.yaml")
    base_config["provider"]["availability_zone"] = "us-west-2c,us-west-2d,us-west-2a"
    config = _configure_subnet(base_config)

    # We've filtered down to only subnets in 2c, 2d & 2a
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert len(node_config["SubnetIds"]) == 15
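        # The stub encodes each subnet's AZ in the numeric suffix of its ID
        # (suffix % 4 -> AZs a, b, c, d), so the AZ order can be recovered
        # from the subnet IDs alone.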
        offsets = [int(s.split("-")[1]) % 4 for s in node_config["SubnetIds"]]
        assert set(offsets[:5]) == {2}, "First 5 should be in us-west-2c"
        assert set(offsets[5:10]) == {3}, "Next 5 should be in us-west-2d"
        assert set(offsets[10:15]) == {0}, "Last 5 should be in us-west-2a"


def test_cloudwatch_dashboard_creation(cloudwatch_client_stub, ssm_client_stub):
    # create test cluster node IDs and an associated cloudwatch helper
    node_id = "i-abc"
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to create a cluster CloudWatch Dashboard...
    # expect to make a call to create a dashboard for each node in the cluster
    stubs.put_cluster_dashboard_success(
        cloudwatch_client_stub,
        cloudwatch_helper,
    )

    # given our mocks and the example CloudWatch Dashboard config as input...
    # expect a cluster CloudWatch Dashboard to be created successfully
    cloudwatch_helper._put_cloudwatch_dashboard()
    # expect no pending responses left in the CloudWatch client stub queue
    cloudwatch_client_stub.assert_no_pending_responses()


def test_cloudwatch_alarm_creation(cloudwatch_client_stub, ssm_client_stub):
    # create test cluster node IDs and an associated cloudwatch helper
    node_id = "i-abc"
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to update a cluster CloudWatch Alarm Config without any
    # change...
    # expect the stored CloudWatch Alarm Config to be the same as the local
    # config
    cw_ssm_param_name = helpers.get_ssm_param_name(
        cloudwatch_helper.cluster_name, "alarm"
    )
    stubs.get_param_ssm_same(
        ssm_client_stub, cw_ssm_param_name, cloudwatch_helper, "alarm"
    )

    # given a directive to create cluster CloudWatch alarms...
    # expect to make a call to create alarms for each node in the cluster
    stubs.put_cluster_alarms_success(cloudwatch_client_stub, cloudwatch_helper)

    # given our mocks and the example CloudWatch Alarm config as input...
    # expect cluster alarms to be created successfully
    cloudwatch_helper._put_cloudwatch_alarm()

    # expect no pending responses left in the CloudWatch client stub queue
    cloudwatch_client_stub.assert_no_pending_responses()


def test_cloudwatch_agent_update_without_change_head_node(
    ssm_client_stub, ec2_client_stub
):
    # create test cluster head node ID and an associated cloudwatch helper
    node_id = "i-abc"
    is_head_node = True
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to check for the Unified CloudWatch Agent status...
    # expect CloudWatch Agent is installed
    stubs.get_ec2_cwa_installed_tag_true(ec2_client_stub, node_id)

    # given a directive to update a cluster CloudWatch Agent Config without any
    # change...
    # expect the stored CloudWatch Agent Config to be the same as the local
    # config
    cw_ssm_param_name = helpers.get_ssm_param_name(
        cloudwatch_helper.cluster_name, "agent"
    )
    stubs.get_param_ssm_same(
        ssm_client_stub, cw_ssm_param_name, cloudwatch_helper, "agent"
    )

    # given our mocks and the same cloudwatch agent config as input...
    # expect no update performed on CloudWatch Agent Config
    cloudwatch_helper._update_cloudwatch_config("agent", is_head_node)
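
    # expect no pending responses left in client stub queues
    ec2_client_stub.assert_no_pending_responses()
    ssm_client_stub.assert_no_pending_responses()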


def test_cloudwatch_agent_update_with_change_head_node(
    ec2_client_stub, ssm_client_stub
):
    # create test cluster head node ID and an associated cloudwatch helper
    node_id = "i-abc"
    is_head_node = True
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to check for the Unified CloudWatch Agent status...
    # expect CloudWatch Agent is installed
    stubs.get_ec2_cwa_installed_tag_true(ec2_client_stub, node_id)
    # given a directive to update a cluster CloudWatch Agent Config with new
    # changes...
    # expect the stored CloudWatch Agent Config to be different from the
    # local config
    cw_ssm_param_name = helpers.get_ssm_param_name(
        cloudwatch_helper.cluster_name, "agent"
    )
    stubs.get_param_ssm_different(ssm_client_stub, cw_ssm_param_name)

    # given an updated CloudWatch Agent Config file...
    # expect to store the new CloudWatch Agent config as an SSM parameter
    stubs.put_parameter_cloudwatch_config(
        ssm_client_stub, cloudwatch_helper.cluster_name, "agent"
    )

    # given an updated CloudWatch Agent Config file...
    # expect to update the node tag equal to updated config file sha1 hash
    # to reflect the changes in config file
    stubs.update_hash_tag_success(ec2_client_stub, node_id, "agent", cloudwatch_helper)
    # given that the updated CloudWatch Agent Config is put to Parameter Store...
    # expect to send SSM commands to restart (stop, then start) the CloudWatch
    # Agent on all nodes
    cmd_id = stubs.send_command_stop_cwa(ssm_client_stub, node_id)
    # given an SSM command to stop CloudWatch Agent sent to all nodes...
    # expect to wait for the command to complete successfully on every node
    stubs.list_command_invocations_success(ssm_client_stub, node_id, cmd_id)
    cmd_id = stubs.send_command_start_cwa(ssm_client_stub, node_id, cw_ssm_param_name)
    # given an SSM command to start CloudWatch Agent sent to all nodes...
    # expect to wait for the command to complete successfully on every node
    stubs.list_command_invocations_success(ssm_client_stub, node_id, cmd_id)

    # given our mocks and the example CloudWatch Agent config as input...
    # expect CloudWatch Agent configured to use updated file on each cluster
    # node successfully
    cloudwatch_helper._update_cloudwatch_config("agent", is_head_node)

    # expect no pending responses left in client stub queues
    ec2_client_stub.assert_no_pending_responses()
    ssm_client_stub.assert_no_pending_responses()


def test_cloudwatch_agent_update_with_change_worker_node(
    ec2_client_stub, ssm_client_stub
):
    # create test cluster worker node ID and an associated cloudwatch helper
    node_id = "i-abc"
    is_head_node = False
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to check for the Unified CloudWatch Agent status...
    # expect CloudWatch Agent is installed
    stubs.get_ec2_cwa_installed_tag_true(ec2_client_stub, node_id)

    # given a directive to update a cluster CloudWatch Agent Config with new
    # changes...
    # expect the stored CloudWatch Agent Config to be different from the
    # local config
    stubs.get_head_node_config_hash_different(
        ec2_client_stub, "agent", cloudwatch_helper, node_id
    )
    stubs.get_cur_node_config_hash_different(ec2_client_stub, "agent", node_id)

    # given an updated CloudWatch Agent Config file...
    # expect to update the node tag equal to updated config file sha1 hash
    # to reflect the changes in config file
    stubs.update_hash_tag_success(ec2_client_stub, node_id, "agent", cloudwatch_helper)
    # given that the updated CloudWatch Agent Config is put to Parameter Store...
    # expect to send SSM commands to restart (stop, then start) the CloudWatch
    # Agent on all nodes
    cmd_id = stubs.send_command_stop_cwa(ssm_client_stub, node_id)
    # given an SSM command to stop CloudWatch Agent sent to all nodes...
    # expect to wait for the command to complete successfully on every node
    stubs.list_command_invocations_success(ssm_client_stub, node_id, cmd_id)
    cw_ssm_param_name = helpers.get_ssm_param_name(
        cloudwatch_helper.cluster_name, "agent"
    )
    cmd_id = stubs.send_command_start_cwa(ssm_client_stub, node_id, cw_ssm_param_name)
    # given a SSM command to start CloudWatch Agent sent to all nodes...
    # expect to wait for the command to complete successfully on every node
    stubs.list_command_invocations_success(ssm_client_stub, node_id, cmd_id)

    # given our mocks and the example CloudWatch Agent config as input...
    # expect CloudWatch Agent configured to use updated file on each cluster
    # node successfully
    cloudwatch_helper._update_cloudwatch_config("agent", is_head_node)

    # expect no pending responses left in client stub queues
    ec2_client_stub.assert_no_pending_responses()
    ssm_client_stub.assert_no_pending_responses()


def test_cloudwatch_dashboard_update_head_node(
    ec2_client_stub, ssm_client_stub, cloudwatch_client_stub
):
    # create test cluster head node ID and an associated cloudwatch helper
    node_id = "i-abc"
    is_head_node = True
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to check for the Unified CloudWatch Agent status...
    # expect CloudWatch Agent is installed
    stubs.get_ec2_cwa_installed_tag_true(ec2_client_stub, node_id)

    # given a directive to update a cluster CloudWatch Dashboard Config
    # with new changes...
    # expect the stored CloudWatch Dashboard Config to be different from the
    # local config
    cw_ssm_param_name = helpers.get_ssm_param_name(
        cloudwatch_helper.cluster_name, "dashboard"
    )
    stubs.get_param_ssm_different(ssm_client_stub, cw_ssm_param_name)

    # given an updated CloudWatch Dashboard Config file...
    # expect to store the new CloudWatch Dashboard config as an SSM parameter
    stubs.put_parameter_cloudwatch_config(
        ssm_client_stub, cloudwatch_helper.cluster_name, "dashboard"
    )

    # given an updated CloudWatch Dashboard Config file...
    # expect to update the node tag equal to updated config file sha1 hash
    # to reflect the changes in config file
    stubs.update_hash_tag_success(
        ec2_client_stub, node_id, "dashboard", cloudwatch_helper
    )

    # given a directive to create a cluster CloudWatch dashboard...
    # expect to make a call to create a dashboard for each node in the cluster
    stubs.put_cluster_dashboard_success(
        cloudwatch_client_stub,
        cloudwatch_helper,
    )
    # given our mocks and the example CloudWatch Dashboard config as input...
    # expect CloudWatch Dashboard configured to use updated file
    # on each cluster node successfully
    cloudwatch_helper._update_cloudwatch_config("dashboard", is_head_node)

    # expect no pending responses left in client stub queues
    ec2_client_stub.assert_no_pending_responses()
    ssm_client_stub.assert_no_pending_responses()


def test_cloudwatch_dashboard_update_worker_node(
    ec2_client_stub, ssm_client_stub, cloudwatch_client_stub
):
    # create test cluster worker node ID and an associated cloudwatch helper
    node_id = "i-abc"
    is_head_node = False
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to check for the Unified CloudWatch Agent status...
    # expect CloudWatch Agent is installed
    stubs.get_ec2_cwa_installed_tag_true(ec2_client_stub, node_id)

    # given a directive to update a cluster CloudWatch Dashboard Config
    # with new changes...
    # expect the stored CloudWatch Dashboard Config to be different from the
    # local config
    stubs.get_head_node_config_hash_different(
        ec2_client_stub, "dashboard", cloudwatch_helper, node_id
    )
    stubs.get_cur_node_config_hash_different(ec2_client_stub, "dashboard", node_id)

    # given an updated CloudWatch Dashboard Config file...
    # expect to update the node tag equal to updated config file sha1 hash
    # to reflect the changes in config file
    stubs.update_hash_tag_success(
        ec2_client_stub, node_id, "dashboard", cloudwatch_helper
    )

    # given our mocks and the example CloudWatch Dashboard config as input...
    # expect CloudWatch Dashboard configured to use updated file
    # on each cluster node successfully
    cloudwatch_helper._update_cloudwatch_config("dashboard", is_head_node)

    # expect no pending responses left in client stub queues
    ec2_client_stub.assert_no_pending_responses()
    ssm_client_stub.assert_no_pending_responses()


def test_cloudwatch_alarm_update_head_node(
    ec2_client_stub, ssm_client_stub, cloudwatch_client_stub
):
    # create test cluster head node ID and an associated cloudwatch helper
    node_id = "i-abc"
    is_head_node = True
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to check for the Unified CloudWatch Agent status...
    # expect CloudWatch Agent is installed
    stubs.get_ec2_cwa_installed_tag_true(ec2_client_stub, node_id)

    # given a directive to update a cluster CloudWatch Alarm Config with new
    # changes...
    # expect the stored CloudWatch Alarm Config to be different from the
    # local config
    cw_ssm_param_name = helpers.get_ssm_param_name(
        cloudwatch_helper.cluster_name, "alarm"
    )
    stubs.get_param_ssm_different(ssm_client_stub, cw_ssm_param_name)

    # given an updated CloudWatch Alarm Config file...
    # expect to store the new CloudWatch Alarm config as an SSM parameter
    stubs.put_parameter_cloudwatch_config(
        ssm_client_stub, cloudwatch_helper.cluster_name, "alarm"
    )

    # given an updated CloudWatch Alarm Config file...
    # expect to update the node tag equal to updated config file sha1 hash
    # to reflect the changes in config file
    stubs.update_hash_tag_success(ec2_client_stub, node_id, "alarm", cloudwatch_helper)
    stubs.get_param_ssm_same(
        ssm_client_stub, cw_ssm_param_name, cloudwatch_helper, "alarm"
    )

    # given a directive to create cluster CloudWatch Alarms...
    # expect to make a call to create alarms for each node in the cluster
    stubs.put_cluster_alarms_success(cloudwatch_client_stub, cloudwatch_helper)

    # given our mocks and the example CloudWatch Alarm config as input...
    # expect CloudWatch Alarm configured to use updated file on each cluster
    # node successfully
    cloudwatch_helper._update_cloudwatch_config("alarm", is_head_node)

    # expect no pending responses left in client stub queues
    ec2_client_stub.assert_no_pending_responses()
    ssm_client_stub.assert_no_pending_responses()


def test_cloudwatch_alarm_update_worker_node(
    ec2_client_stub, ssm_client_stub, cloudwatch_client_stub
):
    # create test cluster worker node ID and an associated cloudwatch helper
    node_id = "i-abc"
    is_head_node = False
    cloudwatch_helper = helpers.get_cloudwatch_helper(node_id)

    # given a directive to check for the Unified CloudWatch Agent status...
    # expect CloudWatch Agent is installed
    stubs.get_ec2_cwa_installed_tag_true(ec2_client_stub, node_id)

    cw_ssm_param_name = helpers.get_ssm_param_name(
        cloudwatch_helper.cluster_name, "alarm"
    )

    # given a directive to update a cluster CloudWatch Alarm Config with new
    # changes...
    # expect the stored CloudWatch Alarm Config to be different from the
    # local config
    stubs.get_head_node_config_hash_different(
        ec2_client_stub, "alarm", cloudwatch_helper, node_id
    )
    stubs.get_cur_node_config_hash_different(ec2_client_stub, "alarm", node_id)

    # given an updated CloudWatch Alarm Config file...
    # expect to update the node tag equal to updated config file sha1 hash
    # to reflect the changes in config file
    stubs.update_hash_tag_success(ec2_client_stub, node_id, "alarm", cloudwatch_helper)
    stubs.get_param_ssm_same(
        ssm_client_stub, cw_ssm_param_name, cloudwatch_helper, "alarm"
    )

    # given a directive to create cluster CloudWatch Alarms...
    # expect to make a call to create alarms for each node in the cluster
    stubs.put_cluster_alarms_success(cloudwatch_client_stub, cloudwatch_helper)
    # given our mocks and the example CloudWatch Alarm config as input...
    # expect CloudWatch Alarm configured to use updated file on each cluster
    # node successfully
    cloudwatch_helper._update_cloudwatch_config("alarm", is_head_node)

    # expect no pending responses left in client stub queues
    ec2_client_stub.assert_no_pending_responses()
    ssm_client_stub.assert_no_pending_responses()


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))
